/**
Copyright 2016 Siemens AG.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package com.siemens.industrialbenchmark.dynamics;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Properties;

import org.apache.commons.collections.buffer.CircularFifoBuffer;
import org.apache.commons.math3.random.RandomDataGenerator;
import org.apache.log4j.Logger;

import com.google.common.base.Preconditions;
import com.siemens.industrialbenchmark.datavector.action.ActionAbsolute;
import com.siemens.industrialbenchmark.datavector.action.ActionDelta;
import com.siemens.industrialbenchmark.datavector.action.EffectiveAction;
import com.siemens.industrialbenchmark.datavector.state.MarkovianState;
import com.siemens.industrialbenchmark.datavector.state.MarkovianStateDescription;
import com.siemens.industrialbenchmark.datavector.state.ObservableState;
import com.siemens.industrialbenchmark.datavector.state.ObservableStateDescription;
import com.siemens.industrialbenchmark.dynamics.goldstone.GoldstoneEnvironment;
import com.siemens.industrialbenchmark.externaldrivers.setpointgen.SetPointGenerator;
import com.siemens.industrialbenchmark.properties.PropertiesException;
import com.siemens.industrialbenchmark.properties.PropertiesUtil;
import com.siemens.rl.interfaces.DataVector;
import com.siemens.rl.interfaces.Environment;
import com.siemens.rl.interfaces.ExternalDriver;

/** Basic dynamics of the industrial benchmark. Each step consists of the following stages:
 *  <ul>
 *  	<li> update the setpoint: the setpoint influences the state and changes according to a random walk, generated by a {@link SetPointGenerator} (or by the configured external drivers)</li>
 *  	<li> add the influence of the actions to the state</li>
 *  	<li> update the fatigue dynamics</li>
 *  	<li> compute the current operational costs</li>
 *  	<li> delay the operational costs and compute a convolved operational-cost level from the cost history</li>
 *  	<li> update the mis-calibration dynamics and derive the consumption</li>
 *  	<li> compute the reward</li>
 *  </ul>
 *  
 * @author Siegmund Duell, Alexander Hentschel, Michel Tokic
 */
public class IndustrialBenchmarkDynamics implements Environment
{
	/** Scale of a relative velocity change per step; absolute velocity actions move at most +/- this value per step. */
	protected final float STEP_SIZE_VELOCITY;
	/** Scale of a relative gain change per step; absolute gain actions move at most +/- this value per step. */
	protected final float STEP_SIZE_GAIN;

	/** Ring buffer of fixed size (one slot per convolution weight) implementing a FIFO queue of operational costs. */
	protected CircularFifoBuffer mOperationalCostsBuffer;

	/** Convolution weights applied to the operational-cost history (property {@code ConvArray}). */
	private float[] mEmConvWeights;
	/**
	 * When true, the next computed operational cost overwrites the whole history buffer.
	 * Set by {@link #init()}, cleared after the first cost has been computed or a history was restored.
	 */
	private boolean convToInit = true;

	/** Mis-calibration ("goldstone") sub-dynamics, driven by the effective shift. */
	private GoldstoneEnvironment gsEnvironment;
	/** sin(15 degrees); passed to the goldstone environment and used to scale shift steps. */
	private final float maxRequiredStep = (float) Math.sin(15.0f/180.0f*Math.PI);
	/** The effective (hidden) shift is clamped to [-gsBound, gsBound]. */
	private final float gsBound = 1.5f; 
	/** Weight of the setpoint when computing the effective shift. */
	private final float gsSetPointDependency = 0.02f;

	/** Names of the float constants that are read from the properties. */
	private enum C {
		DGain, DVelocity, DSetPoint, 
		CostSetPoint, CostGain, CostVelocity, DBase,
		STEP_SIZE_GAIN, STEP_SIZE_VELOCITY
	}

	/** Current Markovian state (observable and hidden variables). */
	protected MarkovianState markovState;
	/** Upper bound per state variable (properties {@code <name>_MAX}). */
	protected MarkovianState mMax; 
	/** Lower bound per state variable (properties {@code <name>_MIN}). */
	protected MarkovianState mMin;
	/** Configuration the dynamics were built from. */
	protected final Properties mProperties;

	private final IndustrialBenchmarkRewardFunction mRewardCore;
	private final RandomDataGenerator rda = new RandomDataGenerator(); 
	/** Seed applied to {@link #rda} at the start of every step; stored in the state as raw double bits. */
	private long randomSeed = 0;
	/** Weight of the mis-calibration in the consumption (property {@code CRGS}). */
	private float CRGS;

	private List<String> markovStateAdditionalNames;
	private final List<ExternalDriver> externalDrivers;
	private final ActionDelta zeroAction = new ActionDelta(0, 0, 0);

	/**
	 * Constructor with configuration Properties. A {@link SetPointGenerator} is used as the only external driver.
	 * The dynamics are initialized and one zero action is applied.
	 * @param aProperties The properties objects
	 * @throws PropertiesException if a required property is missing or invalid
	 */
	public IndustrialBenchmarkDynamics(Properties aProperties) throws PropertiesException {
		this(aProperties, createDefaultExternalDrivers(aProperties));
	}

	/**
	 * Constructor with configuration Properties and external driver list.
	 * The dynamics are initialized and one zero action is applied.
	 * @param aProperties The properties objects
	 * @param externalDrivers The list containing external drivers (used directly, not copied); must not be null
	 * @throws PropertiesException if a required property is missing or invalid
	 */
	public IndustrialBenchmarkDynamics(Properties aProperties, List<ExternalDriver> externalDrivers) throws PropertiesException {
		mProperties = aProperties;
		mRewardCore = new IndustrialBenchmarkRewardFunction(aProperties);
		STEP_SIZE_GAIN = PropertiesUtil.getFloat(mProperties, C.STEP_SIZE_GAIN.name(), true);
		STEP_SIZE_VELOCITY = PropertiesUtil.getFloat(mProperties, C.STEP_SIZE_VELOCITY.name(), true);
		this.externalDrivers = Preconditions.checkNotNull(externalDrivers, "externalDrivers must not be null");

		init();
		step(zeroAction);
	}

	/**
	 * Builds the default external driver list, which contains a single {@link SetPointGenerator}.
	 * @param aProperties The properties objects
	 * @return a new mutable list with the default drivers
	 * @throws PropertiesException if the setpoint generator cannot be configured
	 */
	private static List<ExternalDriver> createDefaultExternalDrivers(Properties aProperties) throws PropertiesException {
		List<ExternalDriver> drivers = new ArrayList<ExternalDriver>();
		drivers.add(new SetPointGenerator(aProperties));
		return drivers;
	}

	/**
	 * Initializes the industrial benchmark from the properties.
	 * <p>Rebuilds the operational-cost history, creates the Markovian state (including the variables of
	 * the external drivers), reads bounds ({@code _MIN}/{@code _MAX}) and initial values ({@code _INIT})
	 * per variable, reseeds all random number generators from {@code SEED} (current time if absent),
	 * and replaces NaN values in the state by 0.</p>
	 * @throws PropertiesException if a required property is missing or invalid
	 */
	protected void init() throws PropertiesException {

		// configure convolution variables
		CRGS = PropertiesUtil.getFloat(mProperties, "CRGS", true);
		final String convArray = mProperties.getProperty("ConvArray");
		Preconditions.checkArgument(convArray != null, "property ConvArray must be set");
		mEmConvWeights = getFloatArray(convArray);
		markovStateAdditionalNames = new ArrayList<String>();
		mOperationalCostsBuffer = new CircularFifoBuffer(mEmConvWeights.length);
		// the history starts zero-filled; the first computed cost must replace all of it
		convToInit = true;
		for (int i = 0; i < mEmConvWeights.length; i++) {
			mOperationalCostsBuffer.add(0.0d); // initialize all operational costs with zero
			markovStateAdditionalNames.add("OPERATIONALCOST_" + i); // one state variable per history slot
		}
		markovStateAdditionalNames.addAll(MarkovianStateDescription.getNonConvolutedInternalVariables());

		// add variables from the external drivers (each name only once)
		List<String> extNames = new ArrayList<String>();
		for (ExternalDriver d : this.externalDrivers) {
			for (String n : d.getState().getKeys()) {
				if (!extNames.contains(n)) {
					extNames.add(n);
				}
			}
		}
		markovStateAdditionalNames.addAll(extNames);

		// instantiate markov state with the additional variable names
		markovState = new MarkovianState(markovStateAdditionalNames);
		mMin = new MarkovianState(markovStateAdditionalNames); // lower variable boundaries
		mMax = new MarkovianState(markovStateAdditionalNames); // upper variable boundaries

		// extract variable bounds + initial values from the properties
		for (String v : this.markovState.getKeys()) {
			float init = PropertiesUtil.getFloat(mProperties, v + "_INIT", 0);
			float max = PropertiesUtil.getFloat(mProperties, v + "_MAX", Float.MAX_VALUE);
			float min = PropertiesUtil.getFloat(mProperties, v + "_MIN", -Float.MAX_VALUE);
			Preconditions.checkArgument(max > min,  "variable=%s: max=%s must be > than min=%s", v, max, min);
			Preconditions.checkArgument(init >= min && init <= max,  "variable=%s: init=%s must be between min=%s and max=%s", v, init, min, max);
			mMax.setValue(v, max);
			mMin.setValue(v, min);
			markovState.setValue(v, init);
		}

		// seed all random number generators to allow re-conducting the experiment
		randomSeed = PropertiesUtil.getLong(mProperties, "SEED", System.currentTimeMillis());
		rda.reSeed(randomSeed);

		for (ExternalDriver d : this.externalDrivers) {
			d.setSeed(rda.nextLong(0, Long.MAX_VALUE));
			d.filter(markovState);
		}

		this.gsEnvironment = new GoldstoneEnvironment(24, maxRequiredStep, maxRequiredStep/2.0);

		// set all NaN values to 0.0
		for (String key : markovState.getKeys()) {
			if (Double.isNaN(markovState.getValue(key))) {
				markovState.setValue(key, 0.0);
			}
		}
	}

	/**
	 * Converts a float array represented as a string (e.g. "0.01, 0.2, 0.9") to a Java float[].
	 * All whitespace is ignored.
	 * @param aFloatArrayAsString The float array represented as a String
	 * @return The float[] array
	 * @throws NumberFormatException if a component is not a valid float
	 */
	private static float[] getFloatArray(String aFloatArrayAsString) {
		String components = aFloatArrayAsString.replaceAll("\\s+", "");
		String[] split = components.split(",");
		float[] result = new float[split.length];
		for (int i = 0; i < result.length; i++) {
			result[i] = Float.parseFloat(split[i]);
		}
		return result;
	}

	/**
	 * Returns the observable components from the Markovian state.
	 *
	 * @return a new observable state with values copied from the current Markovian state
	 */
	public ObservableState getState() {
		ObservableState s = new ObservableState();
		for (String key : s.getKeys()) {
			s.setValue(key, this.markovState.getValue(key));
		}
		return s;
	}

	/**
	 * Applies an action to the industrial benchmark and advances the dynamics by one step.
	 * @param aAction The industrial benchmark action; must be an {@link ActionDelta}
	 *                (an {@link ActionAbsolute} is treated as a bounded move towards its target values)
	 * @return the total reward ({@code RewardTotal}) of the successor state
	 * @throws ClassCastException if {@code aAction} is not an {@link ActionDelta}
	 * @throws IllegalStateException if the dynamics constants cannot be read from the properties
	 */
	@Override
	public double step(DataVector aAction) {

		// apply randomSeed to the PRNG and the external drivers + filter (e.g. setpoint)
		this.rda.reSeed(randomSeed);
		for (ExternalDriver d : externalDrivers) {
			d.setSeed(rda.nextLong(0, Long.MAX_VALUE));
			d.filter(this.markovState);
		}

		// add actions to state
		addAction((ActionDelta) aAction);

		try {
			// update spiking fatigue dynamics
			updateFatigue();

			// update current operational cost
			updateCurrentOperationalCost();
		} catch (PropertiesException e) {
			// continuing would leave a half-updated state behind
			throw new IllegalStateException("industrial benchmark properties could not be read during step", e);
		}

		// update convolved operational costs
		updateOperationalCostConvolution();

		// update mis-calibration (goldstone) dynamics
		updateGS();

		updateOperationalCosts();

		// update reward
		mRewardCore.calcReward(markovState);

		// draw and store the seed of the next iteration (bits packed into a double state variable)
		this.randomSeed = rda.nextLong(0, Long.MAX_VALUE);
		this.markovState.setValue(MarkovianStateDescription.RandomSeed, Double.longBitsToDouble(this.randomSeed));

		return this.markovState.getValue(ObservableStateDescription.RewardTotal);
	}

	/**
	 * Computes the noisy consumption from the convolved operational costs and the mis-calibration,
	 * and stores it as {@code Consumption}.
	 */
	private void updateOperationalCosts() {

		double rGS = markovState.getValue(MarkovianStateDescription.MisCalibration);

		// set new consumption
		double eNewHidden = (markovState.getValue(MarkovianStateDescription.OperationalCostsConv) - (CRGS * (rGS - 1.0)));
		double operationalcosts = eNewHidden - rda.nextGaussian(0,  1) * (1+0.005*eNewHidden);

		markovState.setValue(MarkovianStateDescription.Consumption, operationalcosts);
	}

	/**
	 * Limits a step to the interval [-limit, limit]; values inside (including NaN) are returned unchanged.
	 * @param diff  requested change
	 * @param limit maximum absolute change (non-negative)
	 * @return the bounded change
	 */
	private static double limitStep(double diff, double limit) {
		if (diff > limit) {
			return limit;
		} else if (diff < -limit) {
			return -limit;
		}
		return diff;
	}

	/**
	 * Adds the action to the state: velocity, gain and shift are moved (relative actions scaled by the
	 * step sizes, absolute actions bounded by the step sizes) and clipped to their bounds; the effective
	 * shift is derived from shift and setpoint.
	 * @param aAction the action to apply
	 */
	private void addAction(ActionDelta aAction) {

		final boolean absolute = aAction instanceof ActionAbsolute;

		// velocity
		double velocityMax = mMax.getValue(MarkovianStateDescription.Action_Velocity);
		double velocityMin = mMin.getValue(MarkovianStateDescription.Action_Velocity);
		double velocity = Math.min(velocityMax, Math.max(velocityMin, markovState.getValue(MarkovianStateDescription.Action_Velocity) + aAction.getDeltaVelocity() * STEP_SIZE_VELOCITY));
		if (absolute) {
			double velocityToSet = ((ActionAbsolute) aAction).getVelocity();
			double diff = limitStep(velocityToSet - markovState.getValue(MarkovianStateDescription.Action_Velocity), STEP_SIZE_VELOCITY);
			velocity = Math.min(velocityMax, Math.max(velocityMin, markovState.getValue(MarkovianStateDescription.Action_Velocity) + diff));
		}

		// gain
		double gainMax = mMax.getValue(MarkovianStateDescription.Action_Gain);
		double gainMin = mMin.getValue(MarkovianStateDescription.Action_Gain);
		double gain = Math.min(gainMax, Math.max(gainMin, markovState.getValue(MarkovianStateDescription.Action_Gain) + aAction.getDeltaGain() * STEP_SIZE_GAIN));
		if (absolute) {
			double gainToSet = ((ActionAbsolute) aAction).getGain();
			double diff = limitStep(gainToSet - markovState.getValue(MarkovianStateDescription.Action_Gain), STEP_SIZE_GAIN);
			gain = Math.min(gainMax, Math.max(gainMin, markovState.getValue(MarkovianStateDescription.Action_Gain) + diff));
		}

		// shift (range [0, 100]); gsScale = 2*gsBound + 100*gsSetPointDependency (about 5 with the current constants)
		final double gsScale = 2.0f*gsBound + 100.0f*gsSetPointDependency;
		double shift = (float) Math.min(100.0f, Math.max(0.0f, markovState.getValue(MarkovianStateDescription.Action_Shift) + aAction.getDeltaShift()*(maxRequiredStep/0.9f)*100.0f/gsScale));
		if (absolute) {
			final double shiftStepLimit = (maxRequiredStep/0.9f)*100.0f/gsScale;
			double shiftToSet = ((ActionAbsolute) aAction).getShift();
			double diff = limitStep(shiftToSet - markovState.getValue(MarkovianStateDescription.Action_Shift), shiftStepLimit);
			shift = (float) Math.min(100.0f, Math.max(0.0f, markovState.getValue(MarkovianStateDescription.Action_Shift) + diff));
		}
		// effective shift, clamped to [-gsBound, gsBound], also depends on the current setpoint
		double hiddenShift = (float) Math.min(gsBound, Math.max(-gsBound, (gsScale*shift/100.0f - gsSetPointDependency*markovState.getValue(MarkovianStateDescription.SetPoint) - gsBound)));

		markovState.setValue(MarkovianStateDescription.Action_Velocity, velocity);
		markovState.setValue(MarkovianStateDescription.Action_Gain, gain);
		markovState.setValue(MarkovianStateDescription.Action_Shift, shift);
		markovState.setValue(MarkovianStateDescription.EffectiveShift, hiddenShift);
	}

	/**
	 * Updates the spiking fatigue dynamics: {@code Fatigue}, {@code FatigueBase}, the latent fatigue
	 * variables and the effective action components.
	 * @throws PropertiesException if a dynamics constant is missing
	 */
	private void updateFatigue() throws PropertiesException {
		final float expLambda = 0.1f;
		final float actionTolerance = 0.05f;
		final float fatigueAmplification = 1.1f;   
		final float fatigueAmplificationMax = 5.0f;
		final float fatigueAmplificationStart = 1.2f;

		// action
		double velocity = markovState.getValue(MarkovianStateDescription.Action_Velocity);
		double gain = markovState.getValue(MarkovianStateDescription.Action_Gain);
		double setpoint = markovState.getValue(MarkovianStateDescription.SetPoint);

		// hidden state variables for fatigue
		double hiddenStateVelocity = markovState.getValue(MarkovianStateDescription.FatigueLatent1); 
		double hiddenStateGain = markovState.getValue(MarkovianStateDescription.FatigueLatent2);

		EffectiveAction effAction = new EffectiveAction (new ActionAbsolute(velocity, gain, 0.0, this.mProperties), setpoint);
		double  effActionVelocity = effAction.getEffectiveVelocity();
		double  effActionGain = effAction.getEffectiveGain();

		// base noise
		double noiseGain = 2.0 * (1.0/(1.0+Math.exp(-rda.nextExponential(expLambda))) - 0.5);
		double noiseVelocity = 2.0 * (1.0/(1.0+Math.exp(-rda.nextExponential(expLambda))) - 0.5);

		// add spikes
		// keep probability within [0.001, 0.999] because otherwise the binomial sampling fails.
		// NOTE: the order of the random draws below is part of the reproducible dynamics.
		noiseGain += (1-noiseGain) * rda.nextUniform(0,1) * rda.nextBinomial(1, Math.min(Math.max(0.001, effActionGain), 0.999)) * effActionGain;
		noiseVelocity += (1-noiseVelocity) * rda.nextUniform(0,1) * rda.nextBinomial(1, Math.min(Math.max(0.001, effActionVelocity), 0.999)) * effActionVelocity;

		// compute internal dynamics
		// NOTE(review): the two branches below are cross-wired - the velocity tolerance check updates
		// hiddenStateGain and the gain tolerance check updates hiddenStateVelocity. Kept unchanged to
		// preserve the reference dynamics; confirm whether this is intended.
		if (effActionVelocity <= actionTolerance) {
			hiddenStateVelocity = effActionVelocity;
		} else if (hiddenStateGain >= fatigueAmplificationStart) {
			hiddenStateGain = Math.min(fatigueAmplificationMax,  hiddenStateGain*fatigueAmplification);
		} else {
			hiddenStateGain = (hiddenStateGain*0.9f) + ((float)noiseGain/3.0f);
		} 

		if (effActionGain <= actionTolerance) {
			hiddenStateGain = effActionGain;
		} else if (hiddenStateVelocity >= fatigueAmplificationStart) {
			hiddenStateVelocity = Math.min(fatigueAmplificationMax,  hiddenStateVelocity*fatigueAmplification);
		} else {
			hiddenStateVelocity = (hiddenStateVelocity*0.9f) + ((float)noiseVelocity/3.0f);
		}

		double alpha = 0.0f;
		if (Math.max(hiddenStateVelocity, hiddenStateGain) == fatigueAmplificationMax) {
			// bad noise in case fatigueAmplificationMax is reached
			alpha = 1.0 / (1.0+Math.exp(-rda.nextGaussian(2.4,0.4)));
		} else {
			alpha = Math.max(noiseGain,  noiseVelocity);
		}

		final float cDGain = getConst(C.DGain);
		final float cDVelocity = getConst(C.DVelocity);
		final float cDSetPoint = getConst(C.DSetPoint);
		final float cDynBase = getConst(C.DBase);

		// base fatigue is never negative
		double fb = ((cDynBase / ((cDVelocity * velocity) + cDSetPoint)) - cDGain * gain*gain);
		if (fb < 0) {
			fb = 0;
		}
		double f = ((2.f*alpha+1.0) * fb )/ 3.f;

		markovState.setValue(MarkovianStateDescription.Fatigue, f);  
		markovState.setValue(MarkovianStateDescription.FatigueBase, fb);  

		// hidden state variables for fatigue
		markovState.setValue(MarkovianStateDescription.FatigueLatent1, hiddenStateVelocity); 
		markovState.setValue(MarkovianStateDescription.FatigueLatent2, hiddenStateGain);
		markovState.setValue(MarkovianStateDescription.EffectiveActionVelocityAlpha, effActionVelocity); 
		markovState.setValue(MarkovianStateDescription.EffectiveActionGainBeta, effActionGain);
	}

	/**
	 * Computes the current operational cost, exp((cSetPoint*p + cGain*g + cVelocity*v) / 100), stores it as
	 * {@code CurrentOperationalCost} and appends it to the history. Directly after {@link #init()} the whole
	 * history is filled with this first cost.
	 * @throws PropertiesException if a cost constant is missing
	 */
	private void updateCurrentOperationalCost() throws PropertiesException {

		final float cCostSetPoint = getConst(C.CostSetPoint);
		final float cCostGain = getConst(C.CostGain);
		final float cCostVelocity = getConst(C.CostVelocity);
		double setpoint = markovState.getValue(MarkovianStateDescription.SetPoint);
		double gain = markovState.getValue(MarkovianStateDescription.Action_Gain);
		double velocity = markovState.getValue(MarkovianStateDescription.Action_Velocity);
		double costs = cCostSetPoint * setpoint + cCostGain * gain + cCostVelocity * velocity;

		double operationalcosts = (float) Math.exp(costs / 100.);
		markovState.setValue(MarkovianStateDescription.CurrentOperationalCost, operationalcosts);
		mOperationalCostsBuffer.add(operationalcosts);

		if (convToInit) {
			// overwrite the remaining (zero) history entries with the first cost
			for (int i = 1; i < mOperationalCostsBuffer.size(); i++) {
				mOperationalCostsBuffer.add(operationalcosts);
			}
			convToInit = false;
		}
	}

	/**
	 * Convolves the operational-cost history with the convolution weights, stores the result as
	 * {@code OperationalCostsConv} and mirrors each history entry into {@code OPERATIONALCOST_<i>}.
	 */
	private void updateOperationalCostConvolution() {
		double aggregatedOperationalCosts = 0;
		int i = 0;
		for (Object entry : mOperationalCostsBuffer) {
			double operationalcost = (Double) entry;
			aggregatedOperationalCosts += mEmConvWeights[i] * operationalcost;
			markovState.setValue("OPERATIONALCOST_" + i, operationalcost);
			i++;
		}
		markovState.setValue(MarkovianStateDescription.OperationalCostsConv, aggregatedOperationalCosts);
	}

	/** Advances the goldstone environment with the effective shift and stores the mis-calibration variables. */
	private void updateGS() {
		gsEnvironment.setControlPosition(markovState.getValue(MarkovianStateDescription.EffectiveShift));
		markovState.setValue(MarkovianStateDescription.MisCalibration, (float) gsEnvironment.reward());
		markovState.setValue(MarkovianStateDescription.MisCalibrationDomain, gsEnvironment.getDomain());
		markovState.setValue(MarkovianStateDescription.MisCalibrationSystemResponse, gsEnvironment.getSystemResponse());
		markovState.setValue(MarkovianStateDescription.MisCalibrationPhiIdx, gsEnvironment.getPhiIdx());
	}

	/**
	 * Reads a float constant from the properties.
	 * @param aConst the constant's name
	 * @return its value
	 * @throws PropertiesException if the property is missing or invalid
	 */
	private float getConst(C aConst) throws PropertiesException {
		return PropertiesUtil.getFloat(mProperties, aConst.name());
	}

	/** Returns the operational-costs history length. The current operational-costs value is part of the history.
	 *  @return length of the operational-costs history (including the current value)
	 */
	public int getOperationalCostsHistoryLength() {
		return mOperationalCostsBuffer.size();
	}

	/**
	 * Returns a copy of the current <b>Markovian</b> state of the dynamics.
	 *
	 * @return current internal Markovian state of the industrial benchmark
	 */
	public DataVector getInternalMarkovState() {
		return this.markovState.clone();
	}

	/**
	 * Sets the current <b>Markovian</b> state of the dynamics. All key/value pairs are copied into the
	 * internal state, the random seed and the goldstone environment are restored, the operational-cost
	 * history is rebuilt and convolved, the reward is recomputed on the internal state, and the external
	 * drivers (e.g. the setpoint generator) are configured from the given state.
	 * The given state itself is not modified.
	 *
	 * @param markovState The Markovian state from which the values are copied
	 */
	public void setInternalMarkovState(DataVector markovState) {

		// 1) import all key/value pairs
		for (String key : markovState.getKeys()) {
			this.markovState.setValue(key, markovState.getValue(key));
		}

		// 2) restore random number generator state; raw bits so seeds whose bit pattern is a NaN survive
		this.randomSeed = Double.doubleToRawLongBits(this.markovState.getValue(MarkovianStateDescription.RandomSeed));

		this.gsEnvironment.setControlPosition(markovState.getValue(MarkovianStateDescription.EffectiveShift));
		this.gsEnvironment.setDomain(markovState.getValue(MarkovianStateDescription.MisCalibrationDomain));
		this.gsEnvironment.setSystemResponse(markovState.getValue(MarkovianStateDescription.MisCalibrationSystemResponse));
		this.gsEnvironment.setPhiIdx(markovState.getValue(MarkovianStateDescription.MisCalibrationPhiIdx));

		// 3) reconstruct operational-cost history + convolution + reward on the internal state
		double aggregatedOperationalCosts = 0;
		for (int i = 0; i < mEmConvWeights.length; i++) {
			double operationalcost = this.markovState.getValue("OPERATIONALCOST_" + i);
			aggregatedOperationalCosts += operationalcost * mEmConvWeights[i];
			mOperationalCostsBuffer.add(operationalcost);
		}
		// the restored history is complete; the next step must not overwrite it
		convToInit = false;
		this.markovState.setValue(MarkovianStateDescription.OperationalCostsConv, aggregatedOperationalCosts);
		mRewardCore.calcReward(this.markovState);

		// 4) set state variables to the external drivers (e.g. SetPointGenerator parameters)
		for (ExternalDriver d : externalDrivers) {
			d.setConfiguration(markovState);
		}
	}

	/**
	 * Re-initializes the dynamics from the properties (see {@link #init()}). Unlike the constructors,
	 * no initial zero-action step is performed.
	 * @throws IllegalStateException if the properties cannot be read
	 */
	@Override
	public void reset() {
		try {
			this.init();
		} catch (PropertiesException e) {
			throw new IllegalStateException("industrial benchmark could not be re-initialized", e);
		}
	}

	/**
	 * @return the total reward ({@code RewardTotal}) stored in the current state
	 */
	@Override
	public double getReward() {
		return this.markovState.getValue(ObservableStateDescription.RewardTotal);
	}

}
